# Environment setup for the simulation summary below.
# Clear the global environment so results do not depend on leftover objects.
# NOTE(review): this only removes objects; it does not detach packages or
# reset options.
rm(list = ls())
# Packages. Attach order matters: a later package masks same-named functions
# from earlier ones, so keep this order. NOTE(review): grid, ri2, reshape2
# and bandit are not used directly in this file; presumably fn.R needs them
# - confirm before removing.
library(tidyverse)
library(estimatr)
library(grid)
library(ri2)
library(reshape2)
library(bandit)
library(stargazer)
# Work relative to this script's directory so the relative paths used below
# ('fn.R', 'data.rdata', 'Output/...') resolve. NOTE(review): relies on
# rstudioapi's active-document context, so presumably only works when the
# script is run from within RStudio.
setwd(dirname(rstudioapi::getActiveDocumentContext()$path))
# Fixed seed so the simulations are reproducible.
set.seed(10027)
# Project helpers; presumably defines simulate() used below (which would mask
# stats::simulate) - TODO confirm in fn.R.
source('fn.R')

# Restore the analysis data (expected to provide `data` with columns Y, arm
# and ipw).
load('data.rdata')

# Full-sample no-intercept regression of Y on arm, weighted by ipw. Its
# coefficients act as the "true" arm values for the simulations below.
# NOTE(review): presumably `arm` is a factor, giving one coefficient per arm.
mod <- lm_robust(Y ~ arm - 1, weights = ipw, data = data)
true <- mod %>% coef()
# Iteration count passed to simulate(). Assigned after load() so the data
# file cannot override it.
iter <- 10000

# Scenario parameters: for each scenario, the k largest full-sample arm
# estimates, in ascending order with names kept. These are passed to
# simulate() as `probs`. tail(sort(.), k) selects the top k regardless of how
# many arms were estimated; sort() drops any NA coefficients.
true_param <- lapply(c(top7 = 7L, top5 = 5L, top3 = 3L),
                     function(k) tail(sort(true), k))

# Run one simulation per scenario, or reuse a cached run when the .rdata file
# exists. Scenarios run in true_param order. If simulate() draws random
# numbers, that order matters for reproducibility under the seed set above.
# NOTE(review): the save() call is commented out, so this script never writes
# the cache itself.
if (!file.exists('Output/TF4_simulation.rdata')) {
  simlist <- lapply(setNames(nm = names(true_param)), function(scenario) {
    print(scenario)
    simulate(periods = 10, n = 200, probs = true_param[[scenario]],
             iter = iter, static = FALSE, ppmat = TRUE)
  })
  # save(simlist, file = 'Output/TF4_simulation.rdata')
} else {
  # Expected to restore `simlist`.
  load('Output/TF4_simulation.rdata')
}

# Summarise each scenario's simulation output into a per-arm table.
# sumlist[[scenario]] has one row per arm, sorted by descending true estimate.
# Its character columns are:
#   term     - arm label ("Arm ..." derived from the true-model term name)
#   true     - full-sample estimate with its SE in parentheses
#   best     - share of rows flagged as having the highest posterior
#   rmse     - root mean squared error of the simulated estimates
#   coverage - share of intervals strictly containing the true estimate
# All numbers are formatted with three decimals.

# The full-sample estimates do not depend on the scenario, so tidy them once.
true_estimates <- tidy(mod) %>%
  transmute(term, true = estimate, true_se = std.error)

sumlist <- list()
for (i in names(true_param)) {
  # Full-sample estimates for the arms used in this scenario.
  m <- true_estimates %>%
    filter(term %in% names(true_param[[i]]))

  x <- as.data.frame(simlist[[i]][['d_fit']])

  # Map the simulation's arm labels back to the true-model term names.
  # NOTE(review): assumes simulate() labels arms in first-appearance order
  # matching the order of `probs` (i.e. of true_param[[i]]) - confirm in fn.R.
  # incomparables = NA keeps rows with a missing term unmapped (NA).
  zval <- unique(x$term)
  zval_true <- names(true_param[[i]])
  x$term1 <- zval_true[match(x$term, zval, incomparables = NA)]

  sumlist[[i]] <- x %>%
    left_join(m, by = c('term1' = 'term')) %>%
    # Within each value of the `iter` column (presumably the replication
    # index), flag the row(s) with the maximum posterior. Ties are all flagged.
    group_by(iter) %>%
    mutate(best = max(posterior) == posterior) %>%
    group_by(term1, term) %>%
    # summarize() evaluates its arguments in order, so bias, rmse and coverage
    # see the already-summarised `true`. That is equivalent here because
    # `true` comes from the join and is constant within each term1 group.
    summarize(true = mean(true, na.rm = TRUE),
              true_se = mean(true_se, na.rm = TRUE),
              best = mean(best, na.rm = TRUE),
              est  = mean(estimate, na.rm = TRUE),
              se   = mean(std.error, na.rm = TRUE),
              bias = mean(true - estimate, na.rm = TRUE),
              rmse = sqrt(mean((estimate - true)^2, na.rm = TRUE)),
              coverage = mean(conf.low < true & conf.high > true,
                              na.rm = TRUE),
              .groups = 'drop') %>%
    arrange(desc(true)) %>%
    mutate(across(where(is.numeric), ~ sprintf('%0.3f', .x))) %>%
    mutate(term = gsub('arm|zvals', 'Arm ', term1),
           true = paste0(true, ' (', true_se, ')'),
           est = paste0(est, ' (', se, ')')) %>%
    # est and bias are computed but not reported. Remove them from this
    # select() to include them in the table.
    dplyr::select(-true_se, -se, -term1, -est, -bias)
}

# Stack the three scenario tables into one. The Simulation column records
# which scenario each row came from. The result is rendered as a table via
# stargazer, with no summary statistics and no row names.
result <- bind_rows('Top 7' = sumlist[['top7']],
                    'Top 5' = sumlist[['top5']],
                    'Top 3' = sumlist[['top3']],
                    .id = 'Simulation')

stargazer::stargazer(result, summary = FALSE, rownames = FALSE)
